import com.jogamp.opencl.CLBuffer;
import com.jogamp.opencl.CLCommandQueue;
import com.jogamp.opencl.CLContext;
import com.jogamp.opencl.CLDevice;
import com.jogamp.opencl.CLKernel;
import com.jogamp.opencl.CLProgram;

import java.io.File;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Dictionary;
import java.util.Hashtable;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartUtilities;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;

import static java.lang.System.*;
import static com.jogamp.opencl.CLMemory.Mem.*;
import static java.lang.Math.*;

/**
 * JOCL client example that runs the {@code IntegrateHHStep} OpenCL kernel
 * (Hodgkin-Huxley integration) for {@link #ELEM_COUNT} elements and, when
 * {@link #PLOTTING} is enabled, saves voltage-over-time charts for a few
 * randomly sampled elements.
 */
public class AlphaHHKernel_Tuning_Test {

	/** Classpath location of the OpenCL kernel source. */
	private static final String KERNEL_RESOURCE = "/resource/AlphaHHKernel_Tuning.cl";

	/** Name of the kernel function looked up in the built program. */
	private static final String KERNEL_NAME = "IntegrateHHStep";

	/** Random source used to choose which elements get plotted. */
	public static Random randomGenerator = new Random();

	/** Simulated time at the first integration step (same unit as the integration step dt). */
	public static float START_TIME = -30;
	/** Simulated time at which the integration stops (same unit as the integration step dt). */
	public static float END_TIME = 100;

	/** Number of elements (neurons) processed by the kernel. */
	public static int ELEM_COUNT = 302;

	/** Whether charts are generated after the computation. */
	public static boolean PLOTTING = true;
	/** Number of elements sampled for plotting. */
	public static int SAMPLES = 3;

	/**
	 * Sets up an OpenCL context, runs the kernel once over all elements and
	 * optionally plots the recorded voltages of a few sampled elements.
	 *
	 * @param args ignored
	 * @throws IOException if the kernel source cannot be found or read
	 */
	public static void main(String[] args) throws IOException {

		// set up (uses default CLPlatform and creates context for all devices)
		CLContext context = CLContext.create();
		out.println("created "+ context);

		try{
			// an array with available devices
			CLDevice[] devices = context.getDevices();
			if(devices.length == 0)
			{
				throw new IllegalStateException("no OpenCL devices available in " + context);
			}

			for(int i=0; i<devices.length; i++)
			{
				out.println("device-" + i + ": " + devices[i]);
			}

			// have a look at the output and select a device
			CLDevice device = devices[0];
			// ... or use this code to select fastest device
			//CLDevice device = context.getMaxFlopsDevice();

			out.println("using "+ device);
			out.println("max workgroup size: " + device.getMaxWorkGroupSize());
			out.println("max workitems size: " + device.getMaxWorkItemSizes()[0]);

			// create command queue on selected device.
			CLCommandQueue queue = device.createCommandQueue();

			/* constants declaration */
			// max conductances - no need to be buffer (same for all)
			float maxG_K = 36;
			float maxG_Na = 120;
			float maxG_Leak = 0.3f;
			// reverse potentials - no need to be buffers (they're the same for all)
			float E_K = -12;
			float E_Na = 115;
			float E_Leak = 10.613f;
			// I_ext
			float I_ext = 10;
			// integration step
			float dt = 0.01f;
			// total time steps; rounded once so that neither a fractional duration
			// nor float error in the division silently drops a step
			int steps = round((END_TIME - START_TIME) / dt);

			// load sources, create and build program
			java.io.InputStream kernelSource = AlphaHHKernel_Tuning_Test.class.getResourceAsStream(KERNEL_RESOURCE);
			if(kernelSource == null)
			{
				throw new IOException("OpenCL kernel source not found on classpath: " + KERNEL_RESOURCE);
			}
			CLProgram program = context.createProgram(kernelSource).build();

			// get a reference to the kernel function with the name 'IntegrateHHStep'
			CLKernel kernel = program.createCLKernel(KERNEL_NAME);
			int kernelWorkGroupSize = (int)kernel.getWorkGroupSize(device);
			out.println("max kernel workgroup size: " + kernelWorkGroupSize);

			// NOTE: do the same for all the neurons, C. elegans has 302 neurons, even if they don't fire and this is HH (squid) it will hopefully give a first indication
			int elementCount = ELEM_COUNT;                              // Length of arrays to process
			int localWorkSize = min(kernelWorkGroupSize, 256);          // Local work size dimensions
			int globalWorkSize = roundUp(localWorkSize, elementCount);  // rounded up to the nearest multiple of the localWorkSize
			int globalWorkSize_Plotting = roundUp(localWorkSize, elementCount*steps);

			/* input buffers declarations */
			CLBuffer<FloatBuffer> V_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_WRITE);
			CLBuffer<FloatBuffer> x_n_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_WRITE);
			CLBuffer<FloatBuffer> x_m_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_WRITE);
			CLBuffer<FloatBuffer> x_h_in_Buffer = context.createFloatBuffer(globalWorkSize, READ_WRITE);
			CLBuffer<FloatBuffer> plot_results_Buffer = context.createFloatBuffer(globalWorkSize_Plotting, WRITE_ONLY);

			// Sum every buffer: the plot buffer holds one value per element per step
			// and dominates the total, so it must be part of the report.
			long deviceBytes = V_in_Buffer.getCLSize()
					+ x_n_in_Buffer.getCLSize()
					+ x_m_in_Buffer.getCLSize()
					+ x_h_in_Buffer.getCLSize()
					+ plot_results_Buffer.getCLSize();
			out.println("Approx. used device memory (buffers only): " + deviceBytes/1000000 +"MB");

			// fill input buffers with initial conditions
			// NOTE: they'll all be the same for now, but initial conditions for different neurons could be different
			initInputBuffers(V_in_Buffer.getBuffer(), x_n_in_Buffer.getBuffer(), x_m_in_Buffer.getBuffer(), x_h_in_Buffer.getBuffer());
			initResultsBuffers(plot_results_Buffer.getBuffer());

			long compuTime = nanoTime();

			// map the input/output buffers to its input parameters.
			kernel.putArg(maxG_K)
			.putArg(maxG_Na)
			.putArg(maxG_Leak)
			.putArg(E_K)
			.putArg(E_Na)
			.putArg(E_Leak)
			.putArg(I_ext)
			.putArg(dt)
			.putArg(steps)
			.putArgs(V_in_Buffer, x_n_in_Buffer, x_m_in_Buffer, x_h_in_Buffer, plot_results_Buffer)
			.putArg(elementCount)
			.rewind();

			// asynchronous write of data to GPU device, followed by blocking read to get the computed results back.
			queue.putWriteBuffer(V_in_Buffer, false)
			.putWriteBuffer(x_n_in_Buffer, false)
			.putWriteBuffer(x_m_in_Buffer, false)
			.putWriteBuffer(x_h_in_Buffer, false)
			.put1DRangeKernel(kernel, 0, globalWorkSize, localWorkSize)
			.putReadBuffer(plot_results_Buffer, true);

			compuTime = nanoTime() - compuTime;

			out.println("computation took: "+ (compuTime/1000000) +"ms");

			/* PLOTTING SETUP */
			if(PLOTTING)
			{
				// print some sampled charts to make sure we got fine-looking results.
				List<Integer> sampleIndexes = pickSampleIndexes(SAMPLES, elementCount);
				plotSamples(plot_results_Buffer.getBuffer(), sampleIndexes, steps, dt, elementCount);
			}

			out.println("end of HH simulation");
		}finally{
			// cleanup all resources associated with this context.
			context.release();
		}

	}

	/**
	 * Draws distinct random element indexes in the range 0 .. elementCount-1.
	 * At most {@code elementCount} indexes are returned, so asking for more
	 * samples than there are elements cannot loop forever.
	 *
	 * @param sampleCount  number of indexes requested
	 * @param elementCount exclusive upper bound of the indexes
	 * @return the chosen indexes, in the order they were drawn
	 */
	private static List<Integer> pickSampleIndexes(int sampleCount, int elementCount) {
		int wanted = min(sampleCount, elementCount);
		List<Integer> picked = new ArrayList<Integer>();
		while(picked.size() < wanted)
		{
			Integer candidate = randomGenerator.nextInt(elementCount);
			// skip repeats: a repeated index would only re-render the same chart file
			if(!picked.contains(candidate))
			{
				picked.add(candidate);
			}
		}
		return picked;
	}

	/**
	 * Builds and saves one voltage-over-time chart per sampled element.
	 * The value of element {@code i} at step {@code t} is read from
	 * {@code results} at position {@code i + t*elementCount}.
	 *
	 * @param results       buffer read back from the device
	 * @param sampleIndexes element indexes to plot
	 * @param steps         number of integration steps recorded per element
	 * @param dt            integration step, used to derive the time axis
	 * @param elementCount  number of elements per step in {@code results}
	 */
	private static void plotSamples(FloatBuffer results, List<Integer> sampleIndexes, int steps, float dt, int elementCount) {
		for(Integer index : sampleIndexes)
		{
			XYSeries series = new XYSeries("HH_Graph");

			for(int t = 0; t < steps; t++)
			{
				// absolute read, so the buffer position is left untouched
				series.add(t*dt + START_TIME, results.get(index + t*elementCount));
			}

			// Add the series to your data set
			XYSeriesCollection dataset = new XYSeriesCollection();
			dataset.addSeries(series);

			plot(dataset, index);
		}
	}

	/**
	 * Fills the input buffers with the initial conditions: V = -10,
	 * x_n = 0, x_m = 0 and x_h = 1. Every buffer is rewound afterwards.
	 */
	private static void initInputBuffers(FloatBuffer V_in, FloatBuffer x_n_in, FloatBuffer x_m_in, FloatBuffer x_h_in) {
		// initial condition for V is -10
		fill(V_in, -10);

		// initial conditions for x n/m/h
		fill(x_n_in, 0);
		fill(x_m_in, 0);
		fill(x_h_in, 1);
	}

	/** Zeroes the results buffer and rewinds it. */
	private static void initResultsBuffers(FloatBuffer results) {
		fill(results, 0);
	}

	/** Writes {@code value} into every remaining slot of {@code buffer}, then rewinds it. */
	private static void fill(FloatBuffer buffer, float value) {
		while(buffer.hasRemaining())
		{
			buffer.put(value);
		}
		buffer.rewind();
	}

	/**
	 * Rounds {@code globalSize} up to the nearest multiple of {@code groupSize}.
	 *
	 * @param groupSize  work-group size; must be non-zero
	 * @param globalSize size to round up
	 * @return the smallest multiple of {@code groupSize} that is >= {@code globalSize}
	 */
	private static int roundUp(int groupSize, int globalSize) {
		int r = globalSize % groupSize;
		if (r == 0) {
			return globalSize;
		} else {
			return globalSize + groupSize - r;
		}
	}

	/**
	 * Renders the data set as an XY line chart and saves it as
	 * {@code output/HH_Chart_<index>.jpg}. Failures are reported on stderr
	 * and do not abort the run.
	 *
	 * @param dataset series to draw
	 * @param index   element index, used in the file name
	 */
	private static void plot(XYSeriesCollection dataset, int index)
	{
		// Generate the graph
		JFreeChart chart = ChartFactory.createXYLineChart("HH Chart", "time", "Voltage", dataset, PlotOrientation.VERTICAL, true, true, false);

		File target = new File("output/HH_Chart_" + index + ".jpg");
		// make sure the output directory exists, otherwise saving fails on a fresh checkout
		File folder = target.getParentFile();
		if(folder != null && !folder.exists() && !folder.mkdirs())
		{
			err.println("Problem occurred creating chart: cannot create directory " + folder);
			return;
		}

		try {
			ChartUtilities.saveChartAsJPEG(target, chart, 500, 300);
		} catch (IOException e) {
			err.println("Problem occurred creating chart " + target + ": " + e.getMessage());
		}
	}

}